# To add a new cell, type '# %%'
# To add a new markdown cell, type '# %% [markdown]'
# %%
import tensorflow as tf
from tensorflow.contrib import slim
import numpy as np
from tqdm import tqdm
from matplotlib import pyplot as plt
import os


# %%
# Number of examples per training step; used as the leading dimension of the
# sampled labels, the encoder noise and the sampled latent codes.
BATCH_SIZE = 512
# Adam learning rate for the encoder/decoder ("primal") optimizer.
LR_PRIMAL = 2e-5
# Adam learning rate for the discriminator ("dual") optimizer.
LR_DUAL = 1e-4


# %%
def get_data_samples(N):
    """Build an op that draws ``N`` random integer class labels.

    Args:
        N: Number of labels produced each time the op is evaluated.

    Returns:
        An int32 tensor of shape ``[N]`` sampled uniformly from the integers
        0, 1, 2 and 3 (``maxval`` is exclusive).
    """
    return tf.random_uniform(shape=[N], minval=0, maxval=4, dtype=tf.int32)

def encoder_func(x, eps):
    """Map an input batch and a noise batch to 2-D latent codes.

    The two inputs are concatenated along the last axis and passed through
    three 64-unit ELU layers followed by a linear 2-unit output layer.

    Args:
        x: Input tensor; must be concatenable with ``eps`` on the last axis.
        eps: Noise tensor supplying the encoder's source of randomness.

    Returns:
        Tensor with 2 output units per example (no output activation).
    """
    hidden = tf.concat([x, eps], axis=-1)
    for _ in range(3):
        hidden = slim.fully_connected(hidden, 64, activation_fn=tf.nn.elu)

    # Linear head: the latent code is left unconstrained.
    return slim.fully_connected(hidden, 2, activation_fn=None)


def decoder_func(z):
    """Map latent codes to 4 reconstruction logits.

    The latent input is passed through three 64-unit ELU layers followed by
    a linear 4-unit output layer.

    Args:
        z: Latent code tensor.

    Returns:
        Tensor with 4 logits per example (no output activation).
    """
    hidden_units = 64
    num_hidden_layers = 3

    activations = z
    for _ in range(num_hidden_layers):
        activations = slim.fully_connected(
            activations, hidden_units, activation_fn=tf.nn.elu)

    return slim.fully_connected(activations, 4, activation_fn=None)

def discriminator_func(x, z):
    """Score (x, z) pairs with a residual MLP plus a squared-norm term.

    The inputs are concatenated along axis 1, projected to 256 units with an
    ELU layer, passed through five residual blocks and reduced to a single
    value per example. The squared L2 norm of ``z`` is added to that value.

    Args:
        x: Data tensor of rank 2.
        z: Latent tensor of rank 2.

    Returns:
        Rank-1 tensor holding one unnormalised score per example.
    """
    hidden = slim.fully_connected(
        tf.concat([x, z], axis=1), 256, activation_fn=tf.nn.elu)

    for block in range(1, 6):
        # First layer of the branch uses slim's default activation.
        branch = slim.fully_connected(hidden, 256, scope='fc_%d_r0' % block)
        # Second layer is linear with weights initialised to zero.
        branch = slim.fully_connected(
            branch, 256, activation_fn=None, scope='fc_%d_r1' % block,
            weights_initializer=tf.constant_initializer(0.))
        hidden = tf.nn.elu(hidden + branch)

    score = slim.fully_connected(hidden, 1, activation_fn=None)
    score = tf.squeeze(score, axis=1)

    return score + tf.reduce_sum(tf.square(z), axis=1)

def create_scatter(x_test_labels, eps_test, savepath=None):
    """Scatter-plot the inferred latent codes for each of the four classes.

    Uses the module-level ``sess``, ``z_inferred``, ``x_real_labels`` and
    ``eps``: for every class index the corresponding labels and the given
    noise are fed in place of the sampled tensors, and the resulting 2-D
    codes are drawn as one scatter series.

    Args:
        x_test_labels: Indexable by 0..3; entry ``i`` is the value fed for
            ``x_real_labels`` when plotting class ``i``.
        eps_test: Value fed for ``eps`` (same noise for every class).
        savepath: Optional output path. When given, the figure is saved at
            512 dpi and then closed; when omitted, the figure is left open
            so it can still be displayed.
    """
    fig = plt.figure(figsize=(5, 5), facecolor='w')
    try:
        for i in range(4):
            z_out = sess.run(
                z_inferred,
                feed_dict={x_real_labels: x_test_labels[i], eps: eps_test})
            plt.scatter(z_out[:, 0], z_out[:, 1], edgecolor='none', alpha=0.5)

        plt.xlim(-3, 3)
        plt.ylim(-3.5, 3.5)

        plt.axis('off')
        if savepath:
            plt.savefig(savepath, dpi=512)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this, periodic snapshots during training pile up in memory.
        if savepath:
            plt.close(fig)

# Wrap each network builder in a template named 'encoder', 'decoder' and
# 'discriminator'. A template creates its variables on the first call and
# reuses them on later calls, which is what lets the discriminator be applied
# to two different inputs with shared weights.
encoder = tf.make_template('encoder', encoder_func)
decoder = tf.make_template('decoder', decoder_func)
discriminator = tf.make_template('discriminator', discriminator_func)


# %%
# --- Inputs -----------------------------------------------------------------
# Noise fed to the encoder: 64 standard-normal values per example.
eps = tf.random_normal([BATCH_SIZE, 64])
# Random class labels and their one-hot encoding over 4 classes.
x_real_labels = get_data_samples(BATCH_SIZE)
x_real = tf.one_hot(x_real_labels, 4)
# Latent codes drawn from a 2-D standard normal.
z_sampled = tf.random_normal([BATCH_SIZE, 2])
# Latent codes produced by the encoder, and the decoder's logits for them.
z_inferred = encoder(x_real, eps)
x_reconstr_logits = decoder(z_inferred)

# Discriminator scores for (x, encoder z) pairs and for (x, sampled z) pairs;
# both calls go through the same template and therefore share variables.
Tjoint = discriminator(x_real, z_inferred)
Tseperate = discriminator(x_real, z_sampled)

# Per-example reconstruction error: element-wise sigmoid cross-entropy summed
# over the 4 output units.
# NOTE(review): this scores the 4 one-hot entries as independent binary
# outputs rather than as a single categorical -- confirm this is intended.
reconstr_err = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=x_real, logits=x_reconstr_logits),
    axis=1
)

# Encoder/decoder objective: batch mean of reconstruction error plus the
# discriminator score on encoder samples.
loss_primal = tf.reduce_mean(reconstr_err + Tjoint)
# Discriminator objective: logistic loss with encoder pairs labelled 1 and
# sampled-z pairs labelled 0.
loss_dual = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(logits=Tjoint, labels=tf.ones_like(Tjoint))
    + tf.nn.sigmoid_cross_entropy_with_logits(logits=Tseperate, labels=tf.zeros_like(Tseperate))
)

optimizer_primal = tf.train.AdamOptimizer(LR_PRIMAL)
optimizer_dual = tf.train.AdamOptimizer(LR_DUAL)

# Collect trainable variables by template scope. This must run after the
# template calls above, since that is when the variables are created.
qvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "encoder")
pvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "decoder")
dvars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, "discriminator")

# Each train op updates only its own variables: the primal step leaves the
# discriminator untouched, and the dual step leaves encoder/decoder untouched.
train_op_primal = optimizer_primal.minimize(loss_primal, var_list=pvars+qvars)
train_op_dual = optimizer_dual.minimize(loss_dual, var_list=dvars)


# %%
# Open an interactive session and initialise all global variables in the graph
# (network weights and optimizer state) before any training step is run.
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())


# %%
# Fixed evaluation inputs: for each class i, a batch consisting only of label
# i, plus one noise batch that is reused for every snapshot.
x_test_labels = [[i] * BATCH_SIZE for i in range(4)]
eps_test = np.random.randn(BATCH_SIZE, 64) 

# Directory that receives the periodic scatter-plot snapshots.
outdir = './out_toy'
if not os.path.exists(outdir):
    os.makedirs(outdir)
    
progress = tqdm(range(100000))
for i in progress:
    # One encoder/decoder update, followed by two discriminator updates.
    ELBO_out, _ = sess.run([loss_primal, train_op_primal])
    sess.run(train_op_dual)
    sess.run(train_op_dual)

    # NOTE(review): the value shown as "ELBO" is loss_primal, i.e. the
    # quantity being minimised -- confirm the label/sign is intended.
    progress.set_description('ELBO = %.2f' % ELBO_out)
    # Save a latent-space snapshot every 100 iterations, named by step.
    if i % 100 == 0:
        create_scatter(x_test_labels, eps_test, savepath=os.path.join(outdir, '%08d.png' % i))
